from typing import Any

import pytest
from inline_snapshot import snapshot
from mcp.types import TextContent

from fastmcp import FastMCP
from fastmcp.contrib.bulk_tool_caller.bulk_tool_caller import (
    BulkToolCaller,
    CallToolRequest,
    CallToolRequestResult,
)
from fastmcp.tools.tool import Tool


class ToolException(Exception):
    """Exception raised deliberately by the test tools to simulate a tool failure."""


async def error_tool(arg1: str) -> dict[str, Any]:
    """Always fail so tests can observe how tool errors are reported.

    Raises:
        ToolException: on every call; the message embeds ``arg1``.
    """
    failure_message = f"Error in tool with arg1: {arg1}"
    raise ToolException(failure_message)


def error_tool_result_factory(arg1: str) -> CallToolRequestResult:
    """Build the result BulkToolCaller is expected to report for a failing ``error_tool`` call."""
    # Same text BulkToolCaller produces when it wraps the ToolException raised by error_tool.
    expected_text = f"Error calling tool 'error_tool': Error in tool with arg1: {arg1}"
    error_content = [TextContent(text=expected_text, type="text")]
    return CallToolRequestResult(
        isError=True,
        content=error_content,
        tool="error_tool",
        arguments={"arg1": arg1},
    )


async def echo_tool(arg1: str) -> str:
    """Return ``arg1`` unchanged; used as the always-succeeding test tool."""
    return arg1


def echo_tool_result_factory(arg1: str) -> CallToolRequestResult:
    """Build the result BulkToolCaller is expected to report for a successful ``echo_tool`` call.

    Args:
        arg1: The argument passed to ``echo_tool``; it is echoed back as text content.

    Returns:
        A non-error ``CallToolRequestResult`` carrying ``arg1`` as a single text item.
    """
    return CallToolRequestResult(
        isError=False,
        content=[TextContent(text=f"{arg1}", type="text")],
        tool="echo_tool",
        arguments={"arg1": arg1},
    )


async def no_return_tool(arg1: str) -> None:
    """Accept ``arg1`` and return nothing; used to test tools with empty output."""


def no_return_tool_result_factory(arg1: str) -> CallToolRequestResult:
    """Build the result BulkToolCaller is expected to report for a ``no_return_tool`` call.

    Args:
        arg1: The argument passed to ``no_return_tool``.

    Returns:
        A non-error ``CallToolRequestResult`` with empty content, since the tool returns ``None``.
    """
    return CallToolRequestResult(
        isError=False,
        content=[],
        tool="no_return_tool",
        arguments={"arg1": arg1},
    )


@pytest.fixture
def live_server_with_tool() -> FastMCP:
    """Create a FastMCP server with echo_tool, error_tool and no_return_tool registered."""
    server = FastMCP()
    # Registration order matches the original fixture: echo, error, no-return.
    for tool_fn in (echo_tool, error_tool, no_return_tool):
        server.add_tool(Tool.from_function(tool_fn))
    return server


@pytest.fixture
def bulk_caller_live(live_server_with_tool: FastMCP) -> BulkToolCaller:
    """Provide a BulkToolCaller whose tools are registered on the test server."""
    caller = BulkToolCaller()
    caller.register_tools(live_server_with_tool)
    return caller


# Registered names of the test tools (they match the function names above).
ECHO_TOOL_NAME, ERROR_TOOL_NAME, NO_RETURN_TOOL_NAME = (
    "echo_tool",
    "error_tool",
    "no_return_tool",
)


async def test_call_tool_bulk_single_success(bulk_caller_live: BulkToolCaller):
    """call_tool_bulk with one argument set returns exactly one echo result."""
    argument_sets = [{"arg1": "value1"}]

    results = await bulk_caller_live.call_tool_bulk(ECHO_TOOL_NAME, argument_sets)

    assert results == snapshot(
        [
            CallToolRequestResult(
                content=[TextContent(type="text", text="value1")],
                tool="echo_tool",
                arguments={"arg1": "value1"},
            )
        ]
    )


async def test_call_tool_bulk_multiple_success(bulk_caller_live: BulkToolCaller):
    """call_tool_bulk echoes each argument set, preserving input order."""
    argument_sets = [{"arg1": "value1"}, {"arg1": "value2"}]

    results = await bulk_caller_live.call_tool_bulk(ECHO_TOOL_NAME, argument_sets)

    assert results == snapshot(
        [
            CallToolRequestResult(
                content=[TextContent(type="text", text="value1")],
                tool="echo_tool",
                arguments={"arg1": "value1"},
            ),
            CallToolRequestResult(
                content=[TextContent(type="text", text="value2")],
                tool="echo_tool",
                arguments={"arg1": "value2"},
            ),
        ]
    )


async def test_call_tool_bulk_error_stops(bulk_caller_live: BulkToolCaller):
    """With continue_on_error=False, call_tool_bulk halts after the first failing call."""
    # The second argument set must never be attempted.
    argument_sets = [{"arg1": "error_value"}, {"arg1": "value2"}]

    results = await bulk_caller_live.call_tool_bulk(
        ERROR_TOOL_NAME, argument_sets, continue_on_error=False
    )

    assert results == snapshot(
        [
            CallToolRequestResult(
                content=[
                    TextContent(
                        type="text",
                        text="Error calling tool 'error_tool': Error in tool with arg1: error_value",
                    )
                ],
                isError=True,
                tool="error_tool",
                arguments={"arg1": "error_value"},
            )
        ]
    )


async def test_call_tool_bulk_error_continues(bulk_caller_live: BulkToolCaller):
    """Test call_tool_bulk keeps going after an error when continue_on_error=True.

    Both argument sets target error_tool. A caller that honours
    continue_on_error must therefore return one error result per argument set
    instead of stopping after the first failure (the counterpart of
    test_call_tool_bulk_error_stops). Previously this test called
    call_tools_bulk and never exercised call_tool_bulk at all.
    """
    results = await bulk_caller_live.call_tool_bulk(
        ERROR_TOOL_NAME,
        [{"arg1": "error_value"}, {"arg1": "value2"}],
        continue_on_error=True,
    )

    assert results == snapshot(
        [
            CallToolRequestResult(
                content=[
                    TextContent(
                        type="text",
                        text="Error calling tool 'error_tool': Error in tool with arg1: error_value",
                    )
                ],
                isError=True,
                tool="error_tool",
                arguments={"arg1": "error_value"},
            ),
            CallToolRequestResult(
                content=[
                    TextContent(
                        type="text",
                        text="Error calling tool 'error_tool': Error in tool with arg1: value2",
                    )
                ],
                isError=True,
                tool="error_tool",
                arguments={"arg1": "value2"},
            ),
        ]
    )


async def test_call_tools_bulk_single_success(bulk_caller_live: BulkToolCaller):
    """call_tools_bulk with a single echo request returns one matching result."""
    echo_request = CallToolRequest(tool=ECHO_TOOL_NAME, arguments={"arg1": "value1"})

    responses = await bulk_caller_live.call_tools_bulk([echo_request])

    assert responses == snapshot(
        [
            CallToolRequestResult(
                content=[TextContent(type="text", text="value1")],
                tool="echo_tool",
                arguments={"arg1": "value1"},
            )
        ]
    )


async def test_call_tools_bulk_multiple_success(bulk_caller_live: BulkToolCaller):
    """call_tools_bulk runs requests for different tools and returns results in order."""
    echo_request = CallToolRequest(tool=ECHO_TOOL_NAME, arguments={"arg1": "echo_value"})
    silent_request = CallToolRequest(
        tool=NO_RETURN_TOOL_NAME, arguments={"arg1": "no_return_value"}
    )

    responses = await bulk_caller_live.call_tools_bulk([echo_request, silent_request])

    assert responses == snapshot(
        [
            CallToolRequestResult(
                content=[TextContent(type="text", text="echo_value")],
                tool="echo_tool",
                arguments={"arg1": "echo_value"},
            ),
            CallToolRequestResult(
                content=[], tool="no_return_tool", arguments={"arg1": "no_return_value"}
            ),
        ]
    )


async def test_call_tools_bulk_error_stops(bulk_caller_live: BulkToolCaller):
    """With continue_on_error=False, call_tools_bulk skips everything after the first error."""
    failing_request = CallToolRequest(
        tool=ERROR_TOOL_NAME, arguments={"arg1": "error_value"}
    )
    # Would succeed, but must be skipped because the preceding request fails.
    skipped_request = CallToolRequest(
        tool=ECHO_TOOL_NAME, arguments={"arg1": "skipped_value"}
    )

    responses = await bulk_caller_live.call_tools_bulk(
        [failing_request, skipped_request], continue_on_error=False
    )

    assert responses == snapshot(
        [
            CallToolRequestResult(
                content=[
                    TextContent(
                        type="text",
                        text="Error calling tool 'error_tool': Error in tool with arg1: error_value",
                    )
                ],
                isError=True,
                tool="error_tool",
                arguments={"arg1": "error_value"},
            )
        ]
    )


async def test_call_tools_bulk_error_continues(bulk_caller_live: BulkToolCaller):
    """With continue_on_error=True, call_tools_bulk still runs requests after a failure."""
    failing_request = CallToolRequest(
        tool=ERROR_TOOL_NAME, arguments={"arg1": "error_value"}
    )
    succeeding_request = CallToolRequest(
        tool=ECHO_TOOL_NAME, arguments={"arg1": "success_value"}
    )

    responses = await bulk_caller_live.call_tools_bulk(
        [failing_request, succeeding_request], continue_on_error=True
    )

    assert responses == snapshot(
        [
            CallToolRequestResult(
                content=[
                    TextContent(
                        type="text",
                        text="Error calling tool 'error_tool': Error in tool with arg1: error_value",
                    )
                ],
                isError=True,
                tool="error_tool",
                arguments={"arg1": "error_value"},
            ),
            CallToolRequestResult(
                content=[TextContent(type="text", text="success_value")],
                tool="echo_tool",
                arguments={"arg1": "success_value"},
            ),
        ]
    )
